{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import videoflow\n",
    "import videoflow.core.flow as flow\n",
    "from videoflow.consumers import VideofileWriter, CommandlineConsumer\n",
    "from videoflow.producers import VideofileReader\n",
    "from videoflow_contrib.processors.detector_tf import TensorflowObjectDetector\n",
    "from videoflow_contrib.processors.tracker_sort import KalmanFilterBoundingBoxTracker\n",
    "from videoflow.processors.vision.annotators import TrackerAnnotator, BoundingBoxAnnotator\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BoundingBoxesFilter(videoflow.core.node.ProcessorNode):\n",
    "    def __init__(self, class_indexes_to_keep):\n",
    "        self._class_indexes_to_keep = class_indexes_to_keep\n",
    "        super(BoundingBoxesFilter, self).__init__()\n",
    "    \n",
    "    def process(self, dets):\n",
    "        '''\n",
    "        Keeps only the boxes with the class indexes\n",
    "        specified in self._class_indexes_to_keep\n",
    "\n",
    "        - Arguments:\n",
    "            - dets: np.array of shape (nb_boxes, 6) \\\n",
    "                Specifically (nb_boxes, [xmin, ymin, xmax, ymax, class_index, score])\n",
    "        '''\n",
    "        f = np.array([dets[:, 4] == a for a in self._class_indexes_to_keep])\n",
    "        f = np.any(f, axis = 0)\n",
    "        filtered = dets[f]\n",
    "        return filtered\n",
    "\n",
    "def run_flow():\n",
    "    \n",
    "    tensorflow_model_path = \"/Users/dearj019/Downloads/faster_rcnn_resnet101_coco_2018_01_28/frozen_inference_graph.pb\"\n",
    "    tensorflow_model_classes = 90\n",
    "    output_file = \"output.avi\"\n",
    "    output_file2 = \"output2.avi\"\n",
    "    input_file = \"/Users/dearj019/Downloads/disney2.mp4\"\n",
    "    class_labels_path = \"mscoco_labels.pbtxt\"\n",
    "    \n",
    "    reader = VideofileReader(input_file, 2)\n",
    "    detector = TensorflowObjectDetector(tensorflow_model_path, tensorflow_model_classes)(reader)\n",
    "    detector_annotator = BoundingBoxAnnotator(class_labels_path)(reader, detector)\n",
    "    tracker = KalmanFilterBoundingBoxTracker()(detector)\n",
    "    tracker_printer = CommandlineConsumer()(tracker)\n",
    "    annotator = TrackerAnnotator()(reader, tracker)\n",
    "    writer = VideofileWriter(output_file, fps = 30)(annotator)\n",
    "    writer1 = VideofileWriter(output_file2, fps = 30)(detector_annotator)\n",
    "    fl = flow.Flow([reader], [writer], flow_type = flow.BATCH)\n",
    "    fl.run()\n",
    "    fl.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2019-05-16 16:33:56,141 - videoflow.core - INFO - Started task allocation for 8 tasks\n",
      "2019-05-16 16:33:56,183 - videoflow.core - INFO - Started running flow.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4.49908754e+02 8.94678081e-07 6.28376891e+02 1.41900451e+02\n",
      "  2.00000000e+00]\n",
      " [3.17394592e+02 5.27391052e+02 5.92527100e+02 7.90220032e+02\n",
      "  1.00000000e+00]]\n"
     ]
    }
   ],
   "source": [
    "run_flow()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
